#!/usr/bin/env python3

import os
import time
from pathlib import Path
from argparse import Namespace

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from timm.loss import SoftTargetCrossEntropy
from timm.scheduler import CosineLRScheduler
from timm.data import Mixup, create_loader
from timm.utils import accuracy
import lightning as L

from modern_hopfield_attention.model import CustomVisionTransformer, UniversalViT
from modern_hopfield_attention.data import create_vision_dataset
from modern_hopfield_attention.functional import token_cossim


class LitViT(L.LightningModule):
    """Lightning module that trains and validates a vision transformer.

    The module builds everything it needs from a single ``Namespace``:
    the model (``UniversalViT`` or ``CustomVisionTransformer``), the train and
    validation losses, an AdamW optimizer with a timm cosine LR schedule, a
    timm ``Mixup`` transform, and the train/validation dataloaders.

    During validation the first batch on global rank 0 is additionally used
    to dump the tensors collected in ``self.model.hook_input`` to disk and to
    log a per-layer statistic derived from ``token_cossim``.
    """

    def __init__(self, args: Namespace) -> None:
        """Build model, losses, optimizer, scheduler, mixup and dataloaders.

        Args:
            args: Run configuration. The attributes read here are
                ``save_dir``, ``universal``, ``img_size``, ``patch_size``,
                ``in_chans``, ``num_classes``, ``embed_dim``, ``depth``,
                ``num_heads``, ``class_token``, ``drop_path_rate``,
                ``learning_rate``, ``eps``, ``betas``, ``weight_decay``,
                ``t_initial``, ``lr_min``, ``warmup_t``, ``warmup_lr_init``,
                ``mixup``, ``cutmix``, ``label_smoothing``, ``dataset_type``,
                ``dataset_dir``, ``batch_size``, ``re_prob``, ``re_mode``,
                ``re_count``, ``re_split``, ``crop_pct``, ``auto_augment``
                and ``num_workers``.
        """
        super().__init__()
        # save
        self.save_hyperparameters(args)

        # args
        self.save_dir = args.save_dir

        # model: both variants are constructed with the same keyword arguments,
        # so only the class differs.
        model_cls = UniversalViT if args.universal else CustomVisionTransformer
        self.model = model_cls(
            img_size=args.img_size,
            patch_size=args.patch_size,
            in_chans=args.in_chans,
            num_classes=args.num_classes,
            embed_dim=args.embed_dim,
            depth=args.depth,
            num_heads=args.num_heads,
            class_token=args.class_token,
            drop_path_rate=args.drop_path_rate,
        )

        # loss: training targets are the soft labels produced by mixup,
        # validation targets are plain class indices.
        self.train_loss_fn = SoftTargetCrossEntropy()
        self.valid_loss_fn = nn.CrossEntropyLoss()

        # optimizer&scheduler
        self.optimizer = optim.AdamW(
            self.parameters(),
            lr=args.learning_rate,
            eps=args.eps,
            betas=args.betas,
            weight_decay=args.weight_decay,
        )
        self.scheduler = CosineLRScheduler(
            optimizer=self.optimizer,
            t_initial=args.t_initial,
            lr_min=args.lr_min,
            warmup_t=args.warmup_t,
            warmup_lr_init=args.warmup_lr_init,
            warmup_prefix=True,
        )

        # mixup
        # NOTE: the attribute keeps its original (misspelled) name so that any
        # existing reference to `mixip_fn` keeps working.
        self.mixip_fn = Mixup(
            mixup_alpha=args.mixup,
            cutmix_alpha=args.cutmix,
            label_smoothing=args.label_smoothing,
            num_classes=args.num_classes,
        )
        # timm's Mixup asserts at call time when mixing is enabled but neither
        # alpha is positive. Disabling mixing in that case still yields the
        # label-smoothed soft targets that SoftTargetCrossEntropy needs.
        self.mixip_fn.mixup_enabled = args.mixup > 0 or args.cutmix > 0

        # dataloader
        input_size = (args.in_chans, args.img_size, args.img_size)
        train_set = create_vision_dataset(
            args.dataset_type,
            args.dataset_dir,
            split="train",
            is_training=True,
            download=True,
            batch_size=args.batch_size,
            repeats=0,
        )
        self.train_loader = create_loader(
            train_set,
            input_size=input_size,
            batch_size=args.batch_size,
            is_training=True,
            # random erasing
            re_prob=args.re_prob,
            re_mode=args.re_mode,
            re_count=args.re_count,
            re_split=args.re_split,
            # crop
            crop_pct=args.crop_pct,
            # rand augment
            auto_augment=args.auto_augment,
            use_prefetcher=False,
            device="cpu",
            num_workers=args.num_workers,
        )
        valid_set = create_vision_dataset(
            args.dataset_type,
            args.dataset_dir,
            split="validation",
            is_training=False,
            download=True,
            batch_size=args.batch_size,
            repeats=0,
        )
        self.valid_loader = create_loader(
            valid_set,
            input_size=input_size,
            # validation uses a fixed batch size, independent of args.batch_size
            batch_size=32,
            is_training=False,
            # crop
            crop_pct=args.crop_pct,
            use_prefetcher=False,
            device="cpu",
            num_workers=args.num_workers,
        )

    def _log_epoch_metrics(
        self,
        prefix: str,
        metrics: dict[str, torch.Tensor],
        batch_size: int,
    ) -> None:
        """Log scalar tensors as epoch-aggregated ``"<prefix>/<name>"`` metrics.

        Every metric is logged with the same explicit ``batch_size`` so all of
        them are weighted identically when aggregated over the epoch.
        """
        for name, value in metrics.items():
            self.log(
                f"{prefix}/{name}",
                value=value.item(),
                batch_size=batch_size,
                on_step=False,
                on_epoch=True,
                sync_dist=True,
            )

    def _wait_for_device(self) -> None:
        """Block until queued CUDA work on this module's device has finished.

        CUDA kernels are launched asynchronously, so wall-clock timing around
        a forward pass is only meaningful between two synchronisation points.
        Does nothing on non-CUDA devices.
        """
        if self.device.type == "cuda":
            torch.cuda.synchronize(self.device)

    def _save_and_log_hook_inputs(self) -> None:
        """Save the stacked hook inputs and log a per-layer similarity mode.

        The tensors in ``self.model.hook_input`` are stacked along dim 1,
        written to ``<save_dir>/hook/epoch<NNN>/hook_input.pt`` and passed to
        ``token_cossim``; one ``valid/cls_uni(layer=XX)`` value is logged per
        row of the reshaped similarity tensor.
        """
        hook_input = torch.stack(self.model.hook_input, dim=1)

        # save
        hook_dir = Path(self.save_dir) / f"hook/epoch{self.current_epoch:03}"
        hook_dir.mkdir(mode=0o777, parents=True, exist_ok=True)
        torch.save(hook_input, hook_dir / "hook_input.pt")

        similarity = token_cossim(hook_input, vs_clstoken=True)
        # NOTE(review): `view(size(1), -1)` reinterprets memory without moving
        # dim 1 to the front. If `similarity` is laid out batch-first, rows
        # will mix batch and layer entries -- confirm the layout returned by
        # `token_cossim` (a transpose before the reshape may be intended).
        modes = similarity.view(similarity.size(1), -1).mode()[0]

        for idx, mode in enumerate(modes):
            # Called from global rank 0 only, hence rank_zero_only=True.
            self.log(
                f"valid/cls_uni(layer={idx:02})",
                value=mode,
                on_step=False,
                on_epoch=True,
                sync_dist=True,
                rank_zero_only=True,
            )

    def on_fit_start(self) -> None:
        """Ask the model to register its hooks once fitting begins."""
        self.model.register_hooks()

    def on_train_batch_start(self, batch, batch_idx) -> None:
        """Clear the model's hook state before every training batch."""
        self.model.clear_hooks()

    def training_step(
        self, batch: tuple[torch.Tensor, torch.Tensor], batch_idx: int
    ) -> torch.Tensor:
        """Run one mixup-augmented training step.

        Args:
            batch: ``(inputs, targets)`` pair from the training loader.
            batch_idx: Index of the batch within the epoch (unused).

        Returns:
            The soft-target cross-entropy loss for this batch.
        """
        x_plain, y_plain = batch
        x_plain, y_plain = x_plain.to(self.device), y_plain.to(self.device)
        batch_size = x_plain.size(0)

        # mixup: returns the transformed inputs and matching soft targets
        x, y = self.mixip_fn(x_plain, y_plain)

        y_hat = self.model(x)
        loss = self.train_loss_fn(y_hat, y)

        # Accuracy is measured against the original (un-mixed) labels even
        # though the predictions come from the mixed inputs.
        top1, top5, top10 = accuracy(y_hat, y_plain, topk=(1, 5, 10))

        # log
        self._log_epoch_metrics(
            "train",
            {"loss": loss, "top1": top1, "top5": top5, "top10": top10},
            batch_size,
        )

        return loss

    def on_validation_start(self):
        """Clear the model's hook state before a validation run."""
        # NOTE(review): hooks are cleared once per validation run, not per
        # batch; if the hooks append on every forward pass, `hook_input` keeps
        # growing after the first batch -- confirm against the model.
        self.model.clear_hooks()

    def validation_step(
        self,
        batch: tuple[torch.Tensor, torch.Tensor],
        batch_idx: int,
    ) -> None:
        """Evaluate one validation batch and log loss, accuracy and timing.

        ``valid/throughput`` is the wall-clock time in seconds of the forward
        pass plus loss computation for the batch. For the first batch on
        global rank 0 the hook inputs are also saved and summarised.

        Args:
            batch: ``(inputs, targets)`` pair from the validation loader.
            batch_idx: Index of the batch within the validation run.
        """
        x, y = batch
        x, y = x.to(self.device), y.to(self.device)
        batch_size = x.size(0)

        # Synchronise on both sides so the measurement covers the actual
        # device work rather than just the asynchronous kernel launches.
        self._wait_for_device()
        started = time.perf_counter()
        y_hat = self.model(x)
        loss = self.valid_loss_fn(y_hat, y)
        self._wait_for_device()
        elapsed = time.perf_counter() - started

        self.log(
            "valid/throughput",
            value=elapsed,
            batch_size=batch_size,
            on_step=False,
            on_epoch=True,
            sync_dist=True,
        )

        top1, top5, top10 = accuracy(y_hat, y, topk=(1, 5, 10))

        # hook
        if batch_idx == 0 and self.global_rank == 0:
            self._save_and_log_hook_inputs()

        self._log_epoch_metrics(
            "valid",
            {"loss": loss, "top1": top1, "top5": top5, "top10": top10},
            batch_size,
        )

        return

    def configure_optimizers(self):  # -> tuple[list[AdamW], list[dict[str, Any]]]:
        """Return the optimizer and the epoch-interval scheduler config."""
        return [self.optimizer], [{"scheduler": self.scheduler, "interval": "epoch"}]

    def lr_scheduler_step(self, scheduler, metric) -> None:
        """Step the scheduler, passing it the current epoch index explicitly."""
        scheduler.step(epoch=self.current_epoch)

    def train_dataloader(self) -> DataLoader:
        """Return the training loader built in ``__init__``."""
        return self.train_loader

    def val_dataloader(self) -> DataLoader:
        """Return the validation loader built in ``__init__``."""
        return self.valid_loader
